"""Cross platform abstractions for inter-process communication

On Unix, this uses AF_UNIX sockets.
On Windows, this uses NamedPipes.
"""

import base64
import os
import shutil
import sys
import tempfile

from typing import Optional, Callable
from typing_extensions import Final, Type

from types import TracebackType

if sys.platform == 'win32':
    # This may be private, but it is needed for IPC on Windows, and is basically stable
    import _winapi
    import ctypes

    # On Windows, a connection is a raw pipe handle, represented as an int.
    _IPCHandle = int

    # Pipe-related functions called directly from kernel32 via ctypes (presumably
    # because _winapi does not wrap them -- confirm). Both take a pipe handle and
    # return an int status; per the Win32 convention, nonzero means success.
    kernel32 = ctypes.windll.kernel32
    DisconnectNamedPipe = kernel32.DisconnectNamedPipe  # type: Callable[[_IPCHandle], int]
    FlushFileBuffers = kernel32.FlushFileBuffers  # type: Callable[[_IPCHandle], int]
else:
    import socket
    # Elsewhere, a connection is an AF_UNIX socket object.
    _IPCHandle = socket.socket


class IPCException(Exception):
    """Raised when an IPC operation (connecting, waiting, reading, writing) fails."""


class IPCBase:
    """Base class for communication between the dmypy client and server.

    This contains logic shared between the client and server, such as reading
    and writing. Subclasses must set ``connection`` before any I/O method is
    used: a pipe handle (int) on Windows, a connected socket elsewhere.
    """

    connection = None  # type: _IPCHandle

    def __init__(self, name: str, timeout: Optional[float]) -> None:
        """Record the endpoint name and the I/O timeout.

        *timeout* is in seconds. The Windows code paths treat a falsy timeout
        (None or 0) as "wait forever".
        """
        self.name = name
        self.timeout = timeout

    def read(self, size: int = 100000) -> bytes:
        """Read bytes from an IPC connection until it's empty.

        On Windows this reads one pipe message, issuing further reads of up to
        *size* bytes while the read reports ERROR_MORE_DATA. Elsewhere it calls
        recv(size) until the peer shuts down its write side.

        Raises IPCException on Windows if the wait for data fails or times out,
        or if the read is aborted or reports an unexpected status.
        """
        bdata = bytearray()
        if sys.platform == 'win32':
            while True:
                ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
                assert isinstance(err, int)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException("Bad result from I/O wait: {}".format(res))
                except BaseException:
                    # Don't leave an outstanding overlapped read on the handle.
                    ov.cancel()
                    raise
                _, err = ov.GetOverlappedResult(True)
                more = ov.getbuffer()
                if more:
                    bdata.extend(more)
                if err == 0:
                    # we are done!
                    break
                elif err == _winapi.ERROR_MORE_DATA:
                    # The message was larger than `size`; read the rest of it.
                    continue
                elif err == _winapi.ERROR_OPERATION_ABORTED:
                    raise IPCException("ReadFile operation aborted.")
                else:
                    # Don't silently fall through and issue another read on an
                    # unexpected status.
                    raise IPCException("ReadFile failed with error: {}".format(err))
        else:
            while True:
                more = self.connection.recv(size)
                if not more:
                    # Empty result: the peer shut down its write side (EOF).
                    break
                bdata.extend(more)
        return bytes(bdata)

    def write(self, data: bytes) -> None:
        """Write bytes to an IPC connection.

        On Windows *data* is sent as a single pipe write. Elsewhere it is sent
        with sendall(), and then the write side is shut down. That shutdown is
        what lets the peer's read() loop see end-of-message.

        Raises IPCException on Windows if the write fails, times out, or is
        incomplete.
        """
        if sys.platform == 'win32':
            try:
                ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
                assert isinstance(err, int)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException("Bad result from I/O wait: {}".format(res))
                    elif err != 0:
                        raise IPCException("Failed writing to pipe with error: {}".format(err))
                except BaseException:
                    # Don't leave an outstanding overlapped write on the handle.
                    ov.cancel()
                    raise
                bytes_written, err = ov.GetOverlappedResult(True)
                # Explicit checks rather than asserts: asserts are stripped under -O.
                if err != 0:
                    raise IPCException("Failed writing to pipe with error: {}".format(err))
                if bytes_written != len(data):
                    raise IPCException("Incomplete write to pipe: {} of {} bytes".format(
                        bytes_written, len(data)))
            except WindowsError as e:
                raise IPCException("Failed to write with error: {}".format(e.winerror)) from e
        else:
            self.connection.sendall(data)
            # Signal end-of-message: the peer's read() loops until recv() returns b''.
            self.connection.shutdown(socket.SHUT_WR)

    def close(self) -> None:
        """Close the connection.

        On Windows the handle is reset to NULL afterwards, so calling close()
        again is a no-op. Socket close() is already safe to repeat.
        """
        if sys.platform == 'win32':
            if self.connection != _winapi.NULL:
                _winapi.CloseHandle(self.connection)
                # Forget the handle so a repeated close() cannot close a handle
                # value the OS may since have reused.
                self.connection = _winapi.NULL
        else:
            self.connection.close()


class IPCClient(IPCBase):
    """The client side of an IPC connection."""

    def __init__(self, name: str, timeout: Optional[float]) -> None:
        """Connect to the server endpoint *name*.

        On Windows *name* is a named-pipe path; elsewhere it is the filesystem
        path of an AF_UNIX socket. *timeout* is in seconds.

        Raises IPCException on Windows when the pipe does not exist, the wait
        for it times out, or it is busy. Other OS errors propagate unchanged. If
        a handle or socket was already opened when connecting fails, it is closed
        before the error propagates.
        """
        super().__init__(name, timeout)
        if sys.platform == 'win32':
            timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
            try:
                _winapi.WaitNamedPipe(self.name, timeout)
            except FileNotFoundError as e:
                raise IPCException("The NamedPipe at {} was not found.".format(self.name)) from e
            except WindowsError as e:
                if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
                    raise IPCException("Timed out waiting for connection.") from e
                else:
                    raise
            try:
                self.connection = _winapi.CreateFile(
                    self.name,
                    _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
                    0,
                    _winapi.NULL,
                    _winapi.OPEN_EXISTING,
                    _winapi.FILE_FLAG_OVERLAPPED,
                    _winapi.NULL,
                )
            except WindowsError as e:
                if e.winerror == _winapi.ERROR_PIPE_BUSY:
                    raise IPCException("The connection is busy.") from e
                else:
                    raise
            try:
                # Read in message mode, matching how the server creates the pipe.
                _winapi.SetNamedPipeHandleState(self.connection,
                                                _winapi.PIPE_READMODE_MESSAGE,
                                                None,
                                                None)
            except BaseException:
                # Don't leak the pipe handle that was just opened.
                _winapi.CloseHandle(self.connection)
                raise
        else:
            self.connection = socket.socket(socket.AF_UNIX)
            try:
                self.connection.settimeout(timeout)
                self.connection.connect(name)
            except BaseException:
                # Don't leak the socket if the server is absent, refuses, or times out.
                self.connection.close()
                raise

    def __enter__(self) -> 'IPCClient':
        """Return self; the connection is already established by __init__."""
        return self

    def __exit__(self,
                 exc_ty: 'Optional[Type[BaseException]]' = None,
                 exc_val: Optional[BaseException] = None,
                 exc_tb: Optional[TracebackType] = None,
                 ) -> None:
        """Close the connection; exceptions from the with-body are not suppressed."""
        self.close()


class IPCServer(IPCBase):
    """The server side of an IPC connection.

    Each ``with server:`` block waits for one client connection (``__enter__``)
    and releases it afterwards (``__exit__``). Call cleanup() to release the
    listening endpoint itself.
    """

    # Size in bytes of the named pipe's input and output buffers (Windows only).
    BUFFER_SIZE = 2**16  # type: Final

    def __init__(self, name: str, timeout: Optional[float] = None) -> None:
        """Create the listening endpoint derived from *name*.

        On Windows this creates a single-instance named pipe whose name is
        *name* plus a random suffix. Elsewhere it binds an AF_UNIX socket named
        ``<name>.sock`` inside a fresh temporary directory. *timeout* is in
        seconds; None means no timeout.

        Raises IPCException on Windows if the pipe handle is invalid. On other
        platforms, if binding or listening fails, the socket and temporary
        directory are released before the error propagates.
        """
        if sys.platform == 'win32':
            # The random suffix makes the pipe name unique per server instance.
            name = r'\\.\pipe\{}-{}.pipe'.format(
                name, base64.urlsafe_b64encode(os.urandom(6)).decode())
        else:
            name = '{}.sock'.format(name)
        super().__init__(name, timeout)
        if sys.platform == 'win32':
            self.connection = _winapi.CreateNamedPipe(
                self.name,
                _winapi.PIPE_ACCESS_DUPLEX
                | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
                | _winapi.FILE_FLAG_OVERLAPPED,
                _winapi.PIPE_READMODE_MESSAGE
                | _winapi.PIPE_TYPE_MESSAGE
                | _winapi.PIPE_WAIT
                | 0x8,  # PIPE_REJECT_REMOTE_CLIENTS
                1,  # one instance
                self.BUFFER_SIZE,
                self.BUFFER_SIZE,
                _winapi.NMPWAIT_WAIT_FOREVER,
                0,  # Use default security descriptor
            )
            if self.connection == -1:  # INVALID_HANDLE_VALUE
                err = _winapi.GetLastError()
                raise IPCException('Invalid handle to pipe: {}'.format(err))
        else:
            self.sock_directory = tempfile.mkdtemp()
            sockfile = os.path.join(self.sock_directory, self.name)
            try:
                self.sock = socket.socket(socket.AF_UNIX)
                try:
                    self.sock.bind(sockfile)
                    self.sock.listen(1)
                    if timeout is not None:
                        self.sock.settimeout(timeout)
                except BaseException:
                    self.sock.close()
                    raise
            except BaseException:
                # Don't leak the temporary directory when setup fails (for
                # example, when the socket path exceeds the AF_UNIX length limit).
                shutil.rmtree(self.sock_directory, ignore_errors=True)
                raise

    def __enter__(self) -> 'IPCServer':
        """Wait for a client to connect and return self.

        Raises IPCException if the wait for a client fails or times out.
        """
        if sys.platform == 'win32':
            # NOTE: It is theoretically possible that this will hang forever if the
            # client never connects, though this can be "solved" by killing the server
            try:
                ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
            except WindowsError as e:
                # Don't raise if the client already exists, or the client already connected
                if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
                    raise
            else:
                try:
                    timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                    res = _winapi.WaitForSingleObject(ov.event, timeout)
                    # Explicit check rather than an assert: under -O an assert would
                    # vanish and a timeout would then block forever below.
                    if res != _winapi.WAIT_OBJECT_0:
                        raise IPCException("Bad result from I/O wait: {}".format(res))
                except BaseException:
                    ov.cancel()
                    _winapi.CloseHandle(self.connection)
                    # Mark the handle as closed so cleanup() -> close() won't
                    # close it a second time.
                    self.connection = _winapi.NULL
                    raise
                _, err = ov.GetOverlappedResult(True)
                if err != 0:
                    raise IPCException("Failed to connect pipe with error: {}".format(err))
        else:
            try:
                self.connection, _ = self.sock.accept()
            except socket.timeout as e:
                raise IPCException('The socket timed out') from e
        return self

    def __exit__(self,
                 exc_ty: 'Optional[Type[BaseException]]' = None,
                 exc_val: Optional[BaseException] = None,
                 exc_tb: Optional[TracebackType] = None,
                 ) -> None:
        """Release the current client connection.

        On Windows the pipe is flushed, then disconnected; the pipe handle
        stays open for the next client. Elsewhere the accepted socket is closed.
        """
        if sys.platform == 'win32':
            try:
                # Wait for the client to finish reading the last write before disconnecting
                if not FlushFileBuffers(self.connection):
                    raise IPCException("Failed to flush NamedPipe buffer, "
                                       "maybe the client hung up?")
            finally:
                DisconnectNamedPipe(self.connection)
        else:
            self.close()

    def cleanup(self) -> None:
        """Release the listening endpoint.

        On Windows this closes the pipe handle. Elsewhere it closes the listening
        socket and removes the temporary directory holding the socket file.
        """
        if sys.platform == 'win32':
            self.close()
        else:
            self.sock.close()
            shutil.rmtree(self.sock_directory)

    @property
    def connection_name(self) -> str:
        """The address clients should connect to (pipe name or socket path)."""
        if sys.platform == 'win32':
            return self.name
        else:
            return self.sock.getsockname()
